import jieba
import evaluate
import numpy as np

from rouge import Rouge

from scorers.scorer import Scorer


class AutomaticScorer(Scorer):
    """Reference-based automatic metrics (BLEU, ROUGE-L, BERTScore) for English and Chinese.

    Every ``calculate_*`` method takes a ``language`` code (``'en'`` or ``'zh'``),
    a list of prediction strings and a list of reference strings (one reference per
    prediction), and returns a score on a 0-100 scale. Failures inside a metric
    computation are printed and reported as ``0.0`` rather than raised; an
    unsupported ``language`` raises ``ValueError``.
    """

    _SUPPORTED_LANGUAGES = ('en', 'zh')

    def __init__(self):
        super(AutomaticScorer, self).__init__()
        # Loaded ``evaluate`` metrics keyed by metric name, so repeated calls reuse
        # the same metric object instead of calling ``evaluate.load`` every time.
        self._metrics = {}

    @classmethod
    def _check_language(cls, language):
        """Raise ``ValueError`` (an ``Exception`` subclass) for an unsupported language code."""
        if language not in cls._SUPPORTED_LANGUAGES:
            raise ValueError(f'### [Unknown language]: {language!r}')

    @staticmethod
    def _segment_zh(text):
        """Re-segment Chinese ``text`` with jieba (``cut_all=False``) into space-separated words."""
        # Existing spaces are dropped first so already-tokenised input is re-segmented consistently.
        return ' '.join(jieba.cut(text.replace(' ', ''), cut_all=False))

    def _load_metric(self, name):
        """Return the ``evaluate`` metric ``name``, loading it only on first use."""
        metric = self._metrics.get(name)
        if metric is None:
            metric = evaluate.load(name)
            self._metrics[name] = metric
        return metric

    def calculate_bleu(self, language, predictions, references):
        """Return corpus-level sacreBLEU of ``predictions`` against ``references``.

        Args:
            language: ``'en'`` (sacreBLEU ``13a`` tokenizer) or ``'zh'`` (jieba
                re-segmentation, then sacreBLEU ``zh`` tokenizer).
            predictions: Hypothesis strings.
            references: Reference strings, one per prediction.

        Returns:
            The sacreBLEU ``score`` field, or ``0.0`` if the computation fails.

        Raises:
            ValueError: If ``language`` is not ``'en'`` or ``'zh'``.
        """
        self._check_language(language)
        if language == 'zh':
            predictions = [self._segment_zh(prediction) for prediction in predictions]
            references = [self._segment_zh(reference) for reference in references]
            tokenize = 'zh'
        else:
            tokenize = '13a'
        # sacreBLEU takes a list of reference lists: a single reference per prediction here.
        references = [[reference] for reference in references]

        # https://huggingface.co/spaces/evaluate-metric/sacrebleu
        bleu = self._load_metric('sacrebleu')
        try:
            score = bleu.compute(predictions=predictions, references=references, tokenize=tokenize)['score']
        except Exception as e:
            print(f'### [Failed to calculate bleu {language}]: {e}')
            return 0.0

        return score

    def calculate_rouge(self, language, predictions, references):
        """Return the mean per-pair ROUGE-L F-score, scaled to 0-100.

        Args:
            language: ``'en'`` (text used as-is) or ``'zh'`` (jieba re-segmentation).
            predictions: Hypothesis strings.
            references: Reference strings, paired with ``predictions`` via ``zip``
                (extra items in the longer list are ignored).

        Returns:
            ``100 * mean(ROUGE-L F)``. A pair whose scoring fails contributes ``0.0``;
            if there are no pairs at all, ``0.0`` is returned instead of ``nan``.

        Raises:
            ValueError: If ``language`` is not ``'en'`` or ``'zh'``.
        """
        self._check_language(language)
        if language == 'zh':
            predictions = [self._segment_zh(prediction) for prediction in predictions]
            references = [self._segment_zh(reference) for reference in references]

        rouge = Rouge()
        scores = []
        for prediction, reference in zip(predictions, references):
            try:
                score = rouge.get_scores(hyps=[prediction], refs=[reference], avg=True)['rouge-l']['f']
            except Exception as e:
                print(f'### [Failed to calculate rouge {language}]: {e}')
                score = 0.0
            scores.append(score)

        if not scores:
            # np.mean([]) would be nan (with a RuntimeWarning); report 0.0 like other failures.
            return 0.0
        return np.mean(scores) * 100

    def calculate_bertscore(self, language, predictions, references):
        """Return the mean BERTScore F1 of ``predictions`` against ``references``, scaled to 0-100.

        Args:
            language: ``'en'`` or ``'zh'``; passed through as BERTScore's ``lang``.
                For ``'zh'`` all spaces are removed from the texts (no jieba
                segmentation is applied here).
            predictions: Hypothesis strings.
            references: Reference strings, one per prediction.

        Returns:
            ``100 * mean(f1)``, or ``0.0`` if the computation fails or yields no scores.

        Raises:
            ValueError: If ``language`` is not ``'en'`` or ``'zh'``.
        """
        self._check_language(language)
        if language == 'zh':
            predictions = [prediction.replace(' ', '') for prediction in predictions]
            references = [reference.replace(' ', '') for reference in references]

        # https://huggingface.co/spaces/evaluate-metric/bertscore
        bertscore = self._load_metric('bertscore')
        try:
            scores = bertscore.compute(predictions=predictions, references=references, lang=language)['f1']
        except Exception as e:
            print(f'### [Failed to calculate bertscore {language}]: {e}')
            return 0.0

        # len() rather than truthiness so this also works if ``f1`` is an array.
        if len(scores) == 0:
            return 0.0
        return np.mean(scores) * 100
